library(rpart)
library(rattle)
## Warning: package 'rattle' was built under R version 4.2.2
## Loading required package: tibble
## Loading required package: bitops
## Rattle: A free graphical interface for data science with R.
## Versión 5.5.1 Copyright (c) 2006-2021 Togaware Pty Ltd.
## Escriba 'rattle()' para agitar, sacudir y rotar sus datos.
library(tidyverse)
## ── Attaching packages
## ───────────────────────────────────────
## tidyverse 1.3.2 ──
## ✔ ggplot2 3.3.6 ✔ dplyr 1.0.10
## ✔ tidyr 1.2.1 ✔ stringr 1.4.1
## ✔ readr 2.1.2 ✔ forcats 0.5.2
## ✔ purrr 0.3.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(plotly)
## Warning: package 'plotly' was built under R version 4.2.2
##
## Attaching package: 'plotly'
##
## The following object is masked from 'package:ggplot2':
##
## last_plot
##
## The following object is masked from 'package:stats':
##
## filter
##
## The following object is masked from 'package:graphics':
##
## layout
Data-generating process
\[ Z = f(x,y) + \epsilon = \sqrt{(x-9)^2 + (y-9)^2} + \epsilon \text{ con } \epsilon \sim \mathcal{N}(0,\, 0.5^2) \]
# Build a synthetic training set: (x, y) drawn uniformly on [4.5, 13.5]^2 and
# z = Euclidean distance from (9, 9) plus Gaussian noise with sd = 0.5.
# The draws happen in the order x, y, noise, so results are reproducible
# for a fixed seed.
set.seed(911)
n <- 100
dtrain <- data.frame(x = runif(n, 4.5, 13.5), y = runif(n, 4.5, 13.5))
noise <- rnorm(n, mean = 0, sd = 0.5)
dtrain$z <- with(dtrain, sqrt((x - 9)^2 + (y - 9)^2)) + noise
# Visualize the synthetic dataset.
# Interactive 3-D scatter of (x, y, z) with light-blue markers.
add_markers(
  plot_ly(data = dtrain, x = ~x, y = ~y, z = ~z),
  size = 1,
  color = I("lightblue")
)
# 2-D view of the sampled (x, y) locations only; z is not shown here.
ggplot(data = dtrain, mapping = aes(x = x, y = y)) +
  geom_point(color = "orange") +
  theme_light()
$ Z = f(x,y) + \epsilon \approx \sum_{m=1}^{M} c_m \, \mathbb{1}\{(x,y) \in R_m\} + \epsilon $
# Regression tree (method = "anova") for z using both predictors, x and y.
# The response must appear only on the left-hand side of the formula.
# maxdepth = 3 caps the tree at 2^3 = 8 leaves. minsplit = minbucket = 1 and
# cp = 0 relax rpart's other stopping rules, so depth is effectively the
# main limit on growth.
tree <- rpart(
  z ~ x + y,
  data = dtrain,
  method = "anova",
  control = rpart.control(maxdepth = 3, minsplit = 1, minbucket = 1, cp = 0)
)
fancyRpartPlot(tree)
# Tree predictions for every training row.
fitted.values <- predict(tree, newdata = dtrain)
# rpart's node table: one row per node, leaves marked with var == "<leaf>",
# node numbers stored as row names.
frame <- tree$frame
# Numbers of the terminal nodes (leaves).
nodevec <- as.numeric(rownames(frame)[frame$var == "<leaf>"])
# One element per leaf holding the split labels from the root to that leaf
# (path.rpart also prints these paths, as shown below).
path.list <- path.rpart(tree, nodes = nodevec)
##
## node number: 8
## root
## y< 11.74
## y>=6.857
## y>=7.48
##
## node number: 9
## root
## y< 11.74
## y>=6.857
## y< 7.48
##
## node number: 10
## root
## y< 11.74
## y< 6.857
## y>=4.87
##
## node number: 11
## root
## y< 11.74
## y< 6.857
## y< 4.87
##
## node number: 12
## root
## y>=11.74
## y>=11.74
## y< 12.39
##
## node number: 13
## root
## y>=11.74
## y>=11.74
## y>=12.39
##
## node number: 7
## root
## y>=11.74
## y< 11.74
# Turn each leaf's decision path into the axis-aligned rectangle it covers in
# the (x, y) plane. Each rectangle starts as the bounding box of the training
# data and is narrowed by every split on its path:
#   "v< c" sets the upper bound of v to c,
#   "v>=c" sets the lower bound of v to c.
# NOTE: the path labels appear rounded for printing (e.g. "y< 11.74"), so the
# rectangle edges are approximate.
rect_info <- local({
  # The starting bounding box is the same for every leaf; compute it once.
  data_box <- c(
    xmin = min(dtrain$x), xmax = max(dtrain$x),
    ymin = min(dtrain$y), ymax = max(dtrain$y)
  )
  leaf_rects <- lapply(path.list, function(path) {
    bounds <- data_box
    for (split in setdiff(path, "root")) {
      parts <- unlist(str_split(split, "< |>="))
      split_var <- parts[1]
      # Only x/y splits can be drawn in this plane; anything else means the
      # label was not understood, so fail rather than draw a wrong rectangle.
      if (!split_var %in% c("x", "y")) {
        stop("Unexpected split variable in tree path: ", split, call. = FALSE)
      }
      side <- if (str_detect(split, "< ")) "max" else "min"
      bounds[[paste0(split_var, side)]] <- as.numeric(parts[2])
    }
    as.data.frame(as.list(bounds))
  })
  # One row per leaf with columns xmin, xmax, ymin, ymax.
  bind_rows(leaf_rects)
})
# Attach the tree predictions, rounded to one decimal, as a column. On the
# first run dtrain has no fitted.values column yet, so the name resolves to
# the vector returned by predict().
dtrain <- mutate(dtrain, fitted.values = round(fitted.values, 1))

# One label per distinct rounded fitted value, placed at the median (x, y)
# of the points sharing that value. The result is ungrouped.
label_points <- summarise(
  group_by(dtrain, fitted.values),
  x = median(x),
  y = median(y),
  .groups = "drop"
)
# Partition plot layers:
# - one rectangle per leaf region (grey outline, white fill);
# - the training points coloured by their rounded fitted value;
# - one label per fitted value.
ggplot() +
  geom_rect(
    data = rect_info,
    mapping = aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),
    colour = "grey50",
    fill = "white"
  ) +
  geom_point(
    data = dtrain,
    mapping = aes(x = x, y = y, color = fitted.values)
  ) +
  geom_label(
    data = label_points,
    mapping = aes(x = x, y = y, label = fitted.values, color = fitted.values)
  ) +
  labs(color = "Valor ajustado") +
  theme_light()